{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "deepfm_movielens.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "xCGRpOiK3_Wo",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 107
        },
        "outputId": "edb80c12-de4c-4ba4-d7e4-623c74af558f"
      },
      "source": [
        "%pip install pandas\n"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (1.0.5)\n",
            "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas) (2.8.1)\n",
            "Requirement already satisfied: numpy>=1.13.3 in /usr/local/lib/python3.6/dist-packages (from pandas) (1.18.5)\n",
            "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas) (2018.9)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.6.1->pandas) (1.15.0)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "M5ikOM_j4wmj",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 107
        },
        "outputId": "2b9b9334-f65b-4f82-b4b8-3487621b5d53"
      },
      "source": [
        "%pip install scikit-learn\n"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: sklearn in /usr/local/lib/python3.6/dist-packages (0.0)\n",
            "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from sklearn) (0.22.2.post1)\n",
            "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->sklearn) (0.16.0)\n",
            "Requirement already satisfied: scipy>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->sklearn) (1.4.1)\n",
            "Requirement already satisfied: numpy>=1.11.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->sklearn) (1.18.5)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wDNFAcd-4yMn",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 179
        },
        "outputId": "1a55aa86-5a29-4382-a158-867d13f409b2"
      },
      "source": [
        "%pip install deepctr\n"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: deepctr in /usr/local/lib/python3.6/dist-packages (0.8.1)\n",
            "Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from deepctr) (2.10.0)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from deepctr) (2.23.0)\n",
            "Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python3.6/dist-packages (from h5py->deepctr) (1.18.5)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from h5py->deepctr) (1.15.0)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->deepctr) (1.24.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->deepctr) (2020.6.20)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->deepctr) (3.0.4)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->deepctr) (2.10)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AzdnocDq41IU",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 683
        },
        "outputId": "82096cc7-d3a8-4c7a-f1f2-e2e32e5713bf"
      },
      "source": [
        "%pip install tensorflow"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: tensorflow in /usr/local/lib/python3.6/dist-packages (2.3.0)\n",
            "Requirement already satisfied: google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (0.2.0)\n",
            "Requirement already satisfied: numpy<1.19.0,>=1.16.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.18.5)\n",
            "Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.12.1)\n",
            "Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.6.3)\n",
            "Requirement already satisfied: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (2.10.0)\n",
            "Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (0.3.3)\n",
            "Requirement already satisfied: keras-preprocessing<1.2,>=1.1.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.1.2)\n",
            "Requirement already satisfied: tensorflow-estimator<2.4.0,>=2.3.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (2.3.0)\n",
            "Requirement already satisfied: scipy==1.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.4.1)\n",
            "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (3.3.0)\n",
            "Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.32.0)\n",
            "Requirement already satisfied: protobuf>=3.9.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (3.12.4)\n",
            "Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (0.10.0)\n",
            "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (0.35.1)\n",
            "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.15.0)\n",
            "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.1.0)\n",
            "Requirement already satisfied: tensorboard<3,>=2.3.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (2.3.0)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.9.2->tensorflow) (50.3.0)\n",
            "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<3,>=2.3.0->tensorflow) (1.7.0)\n",
            "Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard<3,>=2.3.0->tensorflow) (1.17.2)\n",
            "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard<3,>=2.3.0->tensorflow) (0.4.1)\n",
            "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<3,>=2.3.0->tensorflow) (2.23.0)\n",
            "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<3,>=2.3.0->tensorflow) (3.2.2)\n",
            "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<3,>=2.3.0->tensorflow) (1.0.1)\n",
            "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<3,>=2.3.0->tensorflow) (0.2.8)\n",
            "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<3,>=2.3.0->tensorflow) (4.1.1)\n",
            "Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<3,>=2.3.0->tensorflow) (4.6)\n",
            "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3,>=2.3.0->tensorflow) (1.3.0)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<3,>=2.3.0->tensorflow) (2.10)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<3,>=2.3.0->tensorflow) (2020.6.20)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<3,>=2.3.0->tensorflow) (1.24.3)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<3,>=2.3.0->tensorflow) (3.0.4)\n",
            "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard<3,>=2.3.0->tensorflow) (1.7.0)\n",
            "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.6/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard<3,>=2.3.0->tensorflow) (0.4.8)\n",
            "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3,>=2.3.0->tensorflow) (3.1.0)\n",
            "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard<3,>=2.3.0->tensorflow) (3.1.0)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RA0ND1qk4Ako",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "c8696625-1bb3-4f6d-c73c-99a667fd3df6"
      },
      "source": [
        "import pandas as pd\n",
        "from sklearn.metrics import mean_squared_error\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.preprocessing import LabelEncoder\n",
        "from deepctr.models import DeepFM\n",
        "from deepctr.feature_column import SparseFeat, get_feature_names\n",
        "\n",
        "# --- Configuration ---\n",
        "# NOTE(review): assumes Google Drive is already mounted so that ./drive/My Drive exists;\n",
        "# no mount cell appears in this notebook.\n",
        "DATA_PATH = \"./drive/My Drive/movielens_sample.txt\"\n",
        "SEED = 2020  # fixes the train/test split so reruns are reproducible\n",
        "BATCH_SIZE = 256\n",
        "EPOCHS = 50\n",
        "\n",
        "# Load the MovieLens sample\n",
        "data = pd.read_csv(DATA_PATH)\n",
        "sparse_features = [\"movie_id\", \"user_id\", \"gender\", \"age\", \"occupation\", \"zip\"]\n",
        "target = ['rating']\n",
        "\n",
        "# Label-encode each sparse feature into integer ids\n",
        "for feature in sparse_features:\n",
        "    lbe = LabelEncoder()\n",
        "    data[feature] = lbe.fit_transform(data[feature])\n",
        "\n",
        "# One SparseFeat per feature; vocabulary size = number of distinct values in that feature\n",
        "fixlen_feature_columns = [SparseFeat(feature, data[feature].nunique()) for feature in sparse_features]\n",
        "# Compact summary (name -> vocabulary_size) instead of the full SparseFeat reprs\n",
        "print({fc.name: fc.vocabulary_size for fc in fixlen_feature_columns})\n",
        "# The linear and DNN parts of DeepFM use the same feature columns\n",
        "linear_feature_columns = fixlen_feature_columns\n",
        "dnn_feature_columns = fixlen_feature_columns\n",
        "feature_names = get_feature_names(linear_feature_columns + dnn_feature_columns)\n",
        "\n",
        "# Split into train/test sets (80/20, seeded for reproducibility)\n",
        "train, test = train_test_split(data, test_size=0.2, random_state=SEED)\n",
        "train_model_input = {name: train[name].values for name in feature_names}\n",
        "test_model_input = {name: test[name].values for name in feature_names}\n",
        "\n",
        "# Train DeepFM as a regressor on the rating; 20% of the training rows are used for validation\n",
        "model = DeepFM(linear_feature_columns, dnn_feature_columns, task='regression')\n",
        "model.compile(\"adam\", \"mse\", metrics=['mse'])\n",
        "history = model.fit(train_model_input, train[target].values, batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=True, validation_split=0.2)\n",
        "\n",
        "# Predict on the held-out test set\n",
        "pred_ans = model.predict(test_model_input, batch_size=BATCH_SIZE)\n",
        "\n",
        "# Compute MSE at full precision and derive RMSE from it; round only when displaying\n",
        "mse = mean_squared_error(test[target].values, pred_ans)\n",
        "rmse = mse ** 0.5\n",
        "print(\"test MSE\", round(mse, 4))\n",
        "print(\"test RMSE\", round(rmse, 4))"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[SparseFeat(name='movie_id', vocabulary_size=187, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x7f4c1d37c828>, embedding_name='movie_id', group_name='default_group', trainable=True), SparseFeat(name='user_id', vocabulary_size=193, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x7f4c1d37ca20>, embedding_name='user_id', group_name='default_group', trainable=True), SparseFeat(name='gender', vocabulary_size=2, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x7f4c1d37cb38>, embedding_name='gender', group_name='default_group', trainable=True), SparseFeat(name='age', vocabulary_size=7, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x7f4c1d37c5c0>, embedding_name='age', group_name='default_group', trainable=True), SparseFeat(name='occupation', vocabulary_size=20, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x7f4c1d37c7f0>, embedding_name='occupation', group_name='default_group', trainable=True), SparseFeat(name='zip', vocabulary_size=188, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x7f4c1d37cf28>, embedding_name='zip', group_name='default_group', trainable=True)]\n",
            "Epoch 1/50\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/indexed_slices.py:432: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
            "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "1/1 [==============================] - 0s 304ms/step - loss: 13.7656 - mse: 13.7656 - val_loss: 15.1093 - val_mse: 15.1093\n",
            "Epoch 2/50\n",
            "1/1 [==============================] - 0s 26ms/step - loss: 13.6400 - mse: 13.6400 - val_loss: 14.9949 - val_mse: 14.9949\n",
            "Epoch 3/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 13.5067 - mse: 13.5067 - val_loss: 14.8728 - val_mse: 14.8728\n",
            "Epoch 4/50\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 13.3646 - mse: 13.3646 - val_loss: 14.7429 - val_mse: 14.7429\n",
            "Epoch 5/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 13.2135 - mse: 13.2135 - val_loss: 14.6046 - val_mse: 14.6046\n",
            "Epoch 6/50\n",
            "1/1 [==============================] - 0s 26ms/step - loss: 13.0530 - mse: 13.0530 - val_loss: 14.4581 - val_mse: 14.4581\n",
            "Epoch 7/50\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 12.8827 - mse: 12.8827 - val_loss: 14.3027 - val_mse: 14.3027\n",
            "Epoch 8/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 12.7020 - mse: 12.7020 - val_loss: 14.1380 - val_mse: 14.1380\n",
            "Epoch 9/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 12.5105 - mse: 12.5105 - val_loss: 13.9633 - val_mse: 13.9633\n",
            "Epoch 10/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 12.3076 - mse: 12.3076 - val_loss: 13.7781 - val_mse: 13.7781\n",
            "Epoch 11/50\n",
            "1/1 [==============================] - 0s 29ms/step - loss: 12.0925 - mse: 12.0925 - val_loss: 13.5817 - val_mse: 13.5817\n",
            "Epoch 12/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 11.8642 - mse: 11.8642 - val_loss: 13.3735 - val_mse: 13.3735\n",
            "Epoch 13/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 11.6221 - mse: 11.6221 - val_loss: 13.1527 - val_mse: 13.1527\n",
            "Epoch 14/50\n",
            "1/1 [==============================] - 0s 28ms/step - loss: 11.3653 - mse: 11.3653 - val_loss: 12.9187 - val_mse: 12.9187\n",
            "Epoch 15/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 11.0929 - mse: 11.0929 - val_loss: 12.6708 - val_mse: 12.6708\n",
            "Epoch 16/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 10.8041 - mse: 10.8041 - val_loss: 12.4082 - val_mse: 12.4082\n",
            "Epoch 17/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 10.4982 - mse: 10.4982 - val_loss: 12.1302 - val_mse: 12.1302\n",
            "Epoch 18/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 10.1743 - mse: 10.1743 - val_loss: 11.8362 - val_mse: 11.8362\n",
            "Epoch 19/50\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 9.8322 - mse: 9.8322 - val_loss: 11.5260 - val_mse: 11.5259\n",
            "Epoch 20/50\n",
            "1/1 [==============================] - 0s 26ms/step - loss: 9.4714 - mse: 9.4714 - val_loss: 11.1989 - val_mse: 11.1989\n",
            "Epoch 21/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 9.0918 - mse: 9.0918 - val_loss: 10.8549 - val_mse: 10.8548\n",
            "Epoch 22/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 8.6934 - mse: 8.6934 - val_loss: 10.4935 - val_mse: 10.4935\n",
            "Epoch 23/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 8.2763 - mse: 8.2763 - val_loss: 10.1147 - val_mse: 10.1147\n",
            "Epoch 24/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 7.8412 - mse: 7.8412 - val_loss: 9.7187 - val_mse: 9.7187\n",
            "Epoch 25/50\n",
            "1/1 [==============================] - 0s 26ms/step - loss: 7.3886 - mse: 7.3886 - val_loss: 9.3061 - val_mse: 9.3061\n",
            "Epoch 26/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 6.9199 - mse: 6.9199 - val_loss: 8.8776 - val_mse: 8.8776\n",
            "Epoch 27/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 6.4371 - mse: 6.4371 - val_loss: 8.4344 - val_mse: 8.4344\n",
            "Epoch 28/50\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 5.9428 - mse: 5.9428 - val_loss: 7.9776 - val_mse: 7.9775\n",
            "Epoch 29/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 5.4404 - mse: 5.4403 - val_loss: 7.5091 - val_mse: 7.5091\n",
            "Epoch 30/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 4.9337 - mse: 4.9337 - val_loss: 7.0316 - val_mse: 7.0316\n",
            "Epoch 31/50\n",
            "1/1 [==============================] - 0s 29ms/step - loss: 4.4278 - mse: 4.4278 - val_loss: 6.5480 - val_mse: 6.5480\n",
            "Epoch 32/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 3.9282 - mse: 3.9282 - val_loss: 6.0620 - val_mse: 6.0619\n",
            "Epoch 33/50\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 3.4418 - mse: 3.4418 - val_loss: 5.5776 - val_mse: 5.5776\n",
            "Epoch 34/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 2.9759 - mse: 2.9759 - val_loss: 5.1000 - val_mse: 5.1000\n",
            "Epoch 35/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 2.5390 - mse: 2.5389 - val_loss: 4.6347 - val_mse: 4.6347\n",
            "Epoch 36/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 2.1396 - mse: 2.1396 - val_loss: 4.1876 - val_mse: 4.1875\n",
            "Epoch 37/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 1.7869 - mse: 1.7869 - val_loss: 3.7656 - val_mse: 3.7656\n",
            "Epoch 38/50\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 1.4892 - mse: 1.4891 - val_loss: 3.3757 - val_mse: 3.3757\n",
            "Epoch 39/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 1.2536 - mse: 1.2535 - val_loss: 3.0244 - val_mse: 3.0244\n",
            "Epoch 40/50\n",
            "1/1 [==============================] - 0s 38ms/step - loss: 1.0850 - mse: 1.0850 - val_loss: 2.7175 - val_mse: 2.7174\n",
            "Epoch 41/50\n",
            "1/1 [==============================] - 0s 25ms/step - loss: 0.9845 - mse: 0.9844 - val_loss: 2.4590 - val_mse: 2.4590\n",
            "Epoch 42/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 0.9476 - mse: 0.9476 - val_loss: 2.2506 - val_mse: 2.2506\n",
            "Epoch 43/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 0.9636 - mse: 0.9636 - val_loss: 2.0908 - val_mse: 2.0907\n",
            "Epoch 44/50\n",
            "1/1 [==============================] - 0s 23ms/step - loss: 1.0150 - mse: 1.0150 - val_loss: 1.9750 - val_mse: 1.9750\n",
            "Epoch 45/50\n",
            "1/1 [==============================] - 0s 28ms/step - loss: 1.0799 - mse: 1.0798 - val_loss: 1.8969 - val_mse: 1.8968\n",
            "Epoch 46/50\n",
            "1/1 [==============================] - 0s 21ms/step - loss: 1.1365 - mse: 1.1364 - val_loss: 1.8492 - val_mse: 1.8491\n",
            "Epoch 47/50\n",
            "1/1 [==============================] - 0s 27ms/step - loss: 1.1671 - mse: 1.1671 - val_loss: 1.8258 - val_mse: 1.8258\n",
            "Epoch 48/50\n",
            "1/1 [==============================] - 0s 33ms/step - loss: 1.1620 - mse: 1.1619 - val_loss: 1.8227 - val_mse: 1.8226\n",
            "Epoch 49/50\n",
            "1/1 [==============================] - 0s 27ms/step - loss: 1.1193 - mse: 1.1193 - val_loss: 1.8376 - val_mse: 1.8375\n",
            "Epoch 50/50\n",
            "1/1 [==============================] - 0s 22ms/step - loss: 1.0445 - mse: 1.0445 - val_loss: 1.8695 - val_mse: 1.8695\n",
            "test RMSE 1.2655828696691496\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}